//go:build linux
// +build linux

package netlink

import (
	"bytes"
	"crypto/rand"
	"encoding/hex"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"os/exec"
	"runtime"
	"strings"
	"testing"

	"github.com/vishvananda/netlink/nl"
	"github.com/vishvananda/netns"
	"golang.org/x/sys/unix"
)

// tearDownNetlinkTest is returned by the setUp*NetlinkTest helpers; calling it
// undoes that helper's setup (for example, switching back to the original
// network namespace and unlocking the OS thread).
type tearDownNetlinkTest func()

// skipUnlessRoot skips the calling test or benchmark unless the process's
// real user ID is 0.
func skipUnlessRoot(t testing.TB) {
	t.Helper()

	if uid := os.Getuid(); uid == 0 {
		return
	}
	t.Skip("Test requires root privileges.")
}

// skipUnlessKModuleLoaded skips the test unless every kernel module named in
// moduleNames is listed in /proc/modules. Each missing module is logged before
// the test is skipped. Failing to read /proc/modules fails the test.
func skipUnlessKModuleLoaded(t *testing.T, moduleNames ...string) {
	t.Helper()
	file, err := ioutil.ReadFile("/proc/modules")
	if err != nil {
		t.Fatal("Failed to open /proc/modules", err)
	}

	// Build the set of loaded module names once instead of rescanning the
	// file for every requested module. The module name is the first
	// space-separated field of each /proc/modules line.
	loaded := make(map[string]struct{})
	for _, line := range strings.Split(string(file), "\n") {
		loaded[strings.SplitN(line, " ", 2)[0]] = struct{}{}
	}

	missing := 0
	for _, name := range moduleNames {
		if _, ok := loaded[name]; !ok {
			t.Logf("Test requires missing kmodule %q.", name)
			missing++
		}
	}
	if missing > 0 {
		t.SkipNow()
	}
}

// setUpNetlinkTest moves the calling goroutine into a fresh, anonymous network
// namespace and replaces the package-level handle, so the test cannot disturb
// the host. It returns a teardown function that switches back to the original
// namespace. Tests are skipped unless running as root.
//
// The goroutine stays locked to its OS thread until the teardown runs, since
// the namespace switch applies to the current thread.
func setUpNetlinkTest(t testing.TB) tearDownNetlinkTest {
	skipUnlessRoot(t)
	// Lock the OS thread, then record original namespace
	runtime.LockOSThread()
	origNS, err := netns.Get()
	if err != nil {
		runtime.UnlockOSThread()
		t.Fatal("Failed to get current namespace:", err)
	}
	// Create and enter a fresh namespace
	ns, err := netns.New()
	if err != nil {
		// Attempt to restore before failing; the origNS handle is not needed
		// afterwards either way. The thread is only unlocked when the restore
		// succeeded: a goroutine that exits while locked takes its OS thread
		// with it, so a thread left in the wrong namespace is never reused.
		restoreErr := netns.Set(origNS)
		_ = origNS.Close()
		if restoreErr == nil {
			runtime.UnlockOSThread()
		}
		t.Fatal("Failed to create new namespace:", err)
	}
	// Reinitialize the package-level handle in this namespace
	if pkgHandle != nil {
		// ensure all sockets from the previous Handle are closed
		_ = pkgHandle.Close()
	}
	pkgHandle = &Handle{}

	return func() {
		// Close the new namespace handle
		_ = ns.Close()
		// Restore the original namespace and release its handle on every path.
		err := netns.Set(origNS)
		_ = origNS.Close()
		if err != nil {
			// Fail before UnlockOSThread on purpose: when this runs on the
			// test goroutine, it exits still locked and the runtime discards
			// the thread instead of reusing it in the test namespace.
			t.Fatalf("Failed to restore original namespace: %v", err)
		}
		// Drop the handle created inside the test namespace so later
		// package-level calls are not tied to it.
		if pkgHandle != nil {
			_ = pkgHandle.Close()
		}
		pkgHandle = &Handle{}
		// Unlock the OS thread
		runtime.UnlockOSThread()
	}
}

// setUpNamedNetlinkTest creates a temporary named network namespace with a
// random "netlinktest-<8 hex chars>" name, switches the calling goroutine's OS
// thread into it, and returns the name together with a teardown function.
// Tests are skipped unless running as root.
//
// The goroutine stays locked to its OS thread until the teardown runs. The
// teardown closes the namespace handle, deletes the named namespace, switches
// back to the original namespace and unlocks the thread.
func setUpNamedNetlinkTest(t *testing.T) (string, tearDownNetlinkTest) {
	skipUnlessRoot(t)

	// Lock before recording the current namespace and before creating the new
	// one. netns.NewNamed may itself switch the calling thread, and that must
	// happen on the thread this goroutine keeps. Otherwise it could land on an
	// arbitrary runtime thread that other goroutines are later scheduled on.
	runtime.LockOSThread()

	origNS, err := netns.Get()
	if err != nil {
		runtime.UnlockOSThread()
		t.Fatal("Failed saving orig namespace:", err)
	}

	// restoreOrig switches back to origNS and releases its handle. The thread
	// is unlocked only if the switch succeeded: a goroutine that exits while
	// still locked takes its OS thread with it, so a thread stuck in the test
	// namespace is never handed back to the scheduler.
	restoreOrig := func() error {
		err := netns.Set(origNS)
		_ = origNS.Close()
		if err == nil {
			runtime.UnlockOSThread()
		}
		return err
	}

	// create a random name
	rnd := make([]byte, 4)
	if _, err := rand.Read(rnd); err != nil {
		_ = restoreOrig()
		t.Fatal("failed creating random ns name")
	}
	name := "netlinktest-" + hex.EncodeToString(rnd)

	ns, err := netns.NewNamed(name)
	if err != nil {
		_ = restoreOrig()
		t.Fatal("Failed to create new ns", err)
	}

	cleanup := func() {
		ns.Close()
		if err := netns.DeleteNamed(name); err != nil {
			t.Logf("Failed to delete namespace %q: %v", name, err)
		}
		if err := restoreOrig(); err != nil {
			t.Fatalf("Failed to restore original namespace: %v", err)
		}
	}

	if err := netns.Set(ns); err != nil {
		cleanup()
		t.Fatal("Failed entering new namespace", err)
	}

	return name, cleanup
}

// setUpNetlinkTestWithLoopback performs the same setup as setUpNetlinkTest
// (fresh network namespace on a locked OS thread) and then brings up the "lo"
// interface inside it. It returns setUpNetlinkTest's teardown function.
//
// If "lo" cannot be found or brought up, the namespace is torn down before the
// test fails. This restores the original namespace and releases the namespace
// handles instead of leaking them.
func setUpNetlinkTestWithLoopback(t *testing.T) tearDownNetlinkTest {
	tearDown := setUpNetlinkTest(t)

	// Bring up the loopback interface
	link, err := LinkByName("lo")
	if err != nil {
		tearDown()
		t.Fatalf("Failed to find \"lo\" in new netns: %v", err)
	}
	if err := LinkSetUp(link); err != nil {
		tearDown()
		t.Fatalf("Failed to bring up \"lo\" in new netns: %v", err)
	}

	return tearDown
}

// setUpF overwrites the file at path with value. The file is created if needed
// and truncated first, as with os.Create. The MPLS setup uses it to write knobs
// under /proc/sys. Failure to open, write or close the file fails the test.
func setUpF(t *testing.T, path, value string) {
	t.Helper()
	file, err := os.Create(path)
	if err != nil {
		t.Fatalf("Failed to open %s: %s", path, err)
	}
	// The write itself can be rejected, so check its error instead of
	// discarding it. Close the file before reporting either failure.
	_, writeErr := file.WriteString(value)
	closeErr := file.Close()
	if writeErr != nil {
		t.Fatalf("Failed to write %q to %s: %s", value, path, writeErr)
	}
	if closeErr != nil {
		t.Fatalf("Failed to close %s: %s", path, closeErr)
	}
}

// setUpMPLSNetlinkTest skips the test when /proc/sys/net/mpls/platform_labels
// does not exist. Otherwise it sets up a fresh namespace via setUpNetlinkTest
// and, inside it, sets platform_labels to 1024 and enables MPLS input on "lo".
// It returns the namespace teardown function.
func setUpMPLSNetlinkTest(t *testing.T) tearDownNetlinkTest {
	const platformLabels = "/proc/sys/net/mpls/platform_labels"
	if _, statErr := os.Stat(platformLabels); statErr != nil {
		t.Skip("Test requires MPLS support.")
	}

	tearDown := setUpNetlinkTest(t)
	knobs := [][2]string{
		{platformLabels, "1024"},
		{"/proc/sys/net/mpls/conf/lo/input", "1"},
	}
	for _, kv := range knobs {
		setUpF(t, kv[0], kv[1])
	}
	return tearDown
}

// setUpSEG6NetlinkTest skips the test unless the kernel build config at
// /boot/config-$(uname -r) contains CONFIG_IPV6_SEG6_LWTUNNEL=y. A missing or
// unreadable config file also causes a skip. Otherwise it sets up a fresh
// namespace via setUpNetlinkTest and returns its teardown function.
func setUpSEG6NetlinkTest(t *testing.T) tearDownNetlinkTest {
	// check if SEG6 options are enabled in Kernel Config
	cmd := exec.Command("uname", "-r")
	var out bytes.Buffer
	cmd.Stdout = &out
	if err := cmd.Run(); err != nil {
		t.Fatal("Failed to run: uname -r", err)
	}
	filename := "/boot/config-" + strings.TrimRight(out.String(), "\n")

	// hasKey reports whether the literal key occurs in the file fname. This is
	// the same check `grep key fname` performed for this literal key, without
	// spawning a process. A file that cannot be read counts as "not found".
	hasKey := func(key, fname string) bool {
		config, err := ioutil.ReadFile(fname)
		return err == nil && bytes.Contains(config, []byte(key))
	}
	const key = "CONFIG_IPV6_SEG6_LWTUNNEL=y"
	if !hasKey(key, filename) {
		msg := "Skipped test because it requires SEG6_LWTUNNEL support."
		log.Println(msg)
		t.Skip(msg)
	}
	// Add CONFIG_IPV6_SEG6_HMAC to support seg6_hmac
	// key := string("CONFIG_IPV6_SEG6_HMAC=y")

	return setUpNetlinkTest(t)
}

// setUpNetlinkTestWithKModule skips the test unless every module in
// moduleNames is a loaded kernel module. Otherwise it sets up a fresh
// namespace via setUpNetlinkTest and returns its teardown function.
func setUpNetlinkTestWithKModule(t *testing.T, moduleNames ...string) tearDownNetlinkTest {
	// Check module prerequisites before any namespace is created.
	skipUnlessKModuleLoaded(t, moduleNames...)
	tearDown := setUpNetlinkTest(t)
	return tearDown
}

// setUpNamedNetlinkTestWithKModule skips the test unless every module in
// moduleNames is listed in /proc/modules. Otherwise it creates a temporary
// named namespace via setUpNamedNetlinkTest and returns its name and teardown.
//
// The module check is delegated to skipUnlessKModuleLoaded instead of
// duplicating its /proc/modules parsing here.
func setUpNamedNetlinkTestWithKModule(t *testing.T, moduleNames ...string) (string, tearDownNetlinkTest) {
	skipUnlessKModuleLoaded(t, moduleNames...)
	return setUpNamedNetlinkTest(t)
}

// remountSysfs replaces the sysfs mounted at /sys with a freshly mounted one.
//
// The steps run in order, and the first error stops the sequence and is
// returned:
//  1. mark every mount under "/" as a recursive slave (MS_SLAVE|MS_REC);
//  2. lazily detach the current /sys (MNT_DETACH);
//  3. mount a new sysfs on /sys.
//
// NOTE(review): this changes the caller's mount table; callers presumably run
// it in a private mount namespace — confirm at the call sites.
func remountSysfs() error {
	err := unix.Mount("", "/", "none", unix.MS_SLAVE|unix.MS_REC, "")
	if err == nil {
		err = unix.Unmount("/sys", unix.MNT_DETACH)
	}
	if err == nil {
		err = unix.Mount("", "/sys", "sysfs", 0, "")
	}
	return err
}

// minKernelRequired skips the test when the running kernel's version (as
// parsed by KernelVersion) is older than kernel.major. Failing to determine
// the version fails the test.
func minKernelRequired(t *testing.T, kernel, major int) {
	t.Helper()

	hostKernel, hostMajor, err := KernelVersion()
	if err != nil {
		t.Fatal(err)
	}

	tooOld := hostKernel < kernel || (hostKernel == kernel && hostMajor < major)
	if !tooOld {
		return
	}
	t.Skipf("Host Kernel (%d.%d) does not meet test's minimum required version: (%d.%d)",
		hostKernel, hostMajor, kernel, major)
}

// KernelVersion returns the first two integers of the running kernel's release
// string as reported by uname. For example, "5.15.0-91-generic" yields 5 and
// 15. It returns an error when uname fails or when fewer than two integers
// can be parsed from the release string.
func KernelVersion() (kernel, major int, err error) {
	var uts unix.Utsname
	if err = unix.Uname(&uts); err != nil {
		return kernel, major, err
	}

	// Release is a fixed-size, NUL-terminated array. Copy it up to the first
	// NUL, converting each element with byte(...) so the copy does not depend
	// on the array's element type.
	release := make([]byte, 0, len(uts.Release))
	for i := 0; i < len(uts.Release) && uts.Release[i] != 0; i++ {
		release = append(release, byte(uts.Release[i]))
	}
	str := string(release)

	var suffix string
	if parsed, _ := fmt.Sscanf(str, "%d.%d%s", &kernel, &major, &suffix); parsed < 2 {
		return kernel, major, fmt.Errorf("can't parse kernel version in %q", str)
	}
	return kernel, major, nil
}

// TestMain sets nl.EnableErrorMessageReporting for the whole test binary,
// runs the tests, and exits with their result code.
func TestMain(m *testing.M) {
	nl.EnableErrorMessageReporting = true

	exitCode := m.Run()
	os.Exit(exitCode)
}
